import pytest

from dada import find_model_runner
from pathlib import Path

# Output directory for the comparison plots: every model test passes it to
# model_runner.run() and check_plots() expects the generated PDFs there.
# It is a relative path, so it resolves against the current working directory
# (the tests presumably assume they are launched from the project root — confirm).
test_plot_path = 'tests/plot'


def test_invalid_model():
    """Requesting the invalid model code "A" must make find_model_runner raise ValueError."""
    unknown_model = "A"
    runner_params = dict(
        vector_size=1000,
        radius=1_000_000,
        num_polyhedron=10_000,
        mu_list=None,
        q_list=[1, 1.25, 1.5, 1.75, 2],
        p_list=None,
    )

    with pytest.raises(ValueError):
        find_model_runner(unknown_model, runner_params)


def test_polynomial_feasibility_problem():
    """Run the "PF" (polynomial-feasibility) model and verify both comparison plots appear."""
    name = 'polynomial-feasibility'
    runner = find_model_runner("PF", dict(
        vector_size=1000,
        radius=1_000_000,
        num_polyhedron=10_000,
        mu_list=None,
        q_list=[1, 1.25, 1.5, 1.75, 2],
        p_list=None,
    ))

    # Positional arguments follow the runner's run() signature; the meaning of
    # 10 and True is defined there (presumably an iteration count and a
    # plotting flag — TODO confirm).
    runner.run(10, name, True, test_plot_path)
    check_plots(name)


def test_norm_model():
    """Run the "NORM" model over several p values and verify both comparison plots appear."""
    name = 'norm'
    runner = find_model_runner("NORM", dict(
        vector_size=100,
        radius=None,
        num_polyhedron=None,
        mu_list=None,
        q_list=None,
        p_list=[1.2, 1.4, 2, 4, 6],
    ))

    # Positional arguments follow the runner's run() signature; the meaning of
    # 10 and True is defined there (presumably an iteration count and a
    # plotting flag — TODO confirm).
    runner.run(10, name, True, test_plot_path)
    check_plots(name)


def test_log_sum_exp_model():
    """Run the "LSE" (log-sum-exp) model over several mu values and verify both plots appear."""
    name = 'log-sum-exp'
    runner = find_model_runner("LSE", dict(
        vector_size=100,
        radius=1,
        num_polyhedron=1000,
        mu_list=[1, 0.5, 0.1, 0.05],
        q_list=None,
        p_list=None,
    ))

    # Positional arguments follow the runner's run() signature; the meaning of
    # 10 and True is defined there (presumably an iteration count and a
    # plotting flag — TODO confirm).
    runner.run(10, name, True, test_plot_path)
    check_plots(name)


def check_plots(model_name, plot_dir=None):
    """Assert that both comparison plots for *model_name* exist, then delete them.

    The expected files are ``comparison-<model_name>-estimate-error.pdf`` and
    ``comparison-<model_name>-residual.pdf`` inside *plot_dir*.

    Args:
        model_name: Model name embedded in the plot file names.
        plot_dir: Directory holding the plots. Defaults to the module-level
            ``test_plot_path`` when omitted or ``None``.

    Raises:
        AssertionError: If either plot file is missing. Any plot that does
            exist is deleted before the assertion is raised.
    """
    if plot_dir is None:
        plot_dir = test_plot_path

    plot_paths = [
        Path(plot_dir) / f'comparison-{model_name}-{suffix}.pdf'
        for suffix in ('estimate-error', 'residual')
    ]
    missing = [path for path in plot_paths if not path.exists()]

    # Clean up every plot that was produced *before* asserting. Otherwise a
    # failure on the first file would leave the second behind, and that stale
    # file could make a later run pass even if it no longer generates the plot.
    for path in plot_paths:
        if path.exists():
            path.unlink()

    assert not missing, '; '.join(f"File {path} does not exist." for path in missing)
